{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "72dbaafb",
   "metadata": {},
   "source": [
    "# Chapter 12 Training A Transformer to Generate Text\n",
    "\n",
    "This chapter covers\n",
    "\n",
    "* Building a scaled-down version of the GPT-2XL model tailored to your needs\n",
    "* Preparing training data for training a GPT-style Transformer\n",
    "* Training a GPT-style Transformer from scratch\n",
    "* Generating text using the trained GPT model\n",
    "\n",
    "In Chapter 11, we developed the GPT-2XL model from scratch but were unable to train it due to its vast number of parameters. Training a model with 1.5 billion parameters requires supercomputing facilities and an enormous amount of data. Consequently, we loaded pre-trained weights from OpenAI into our model and then used the GPT-2XL model to generate text.\n",
    "\n",
    "However, Learning how to train a Transformer model from scratch is crucial for several reasons. First, while this book doesn't directly cover fine-tuning a pre-rained model, understanding how to train a Transformer equips you with the skills needed for fine-tuning. Training a model involves initializing parameters randomly, whereas fine-tuning involves loading pre-trained weights and further training the model. Second, training or fine-tuning a Transformer enables you to customize the model to meet your specific needs and domain, which can significantly enhance its performance and relevance for your use case. Finally, training your own Transformer or fine-tuning an existing one provides greater control over data and privacy, which is particularly important for sensitive applications or handling proprietary data. In summary, mastering the training and fine-tuning of Transformers is essential for anyone looking to harness the power of language models for specific applications while maintaining privacy and control. \n",
    "\n",
    "Therefore, in this chapter, we’ll construct a scaled-down version of the GPT model with approximately five million parameters. This smaller model follows the architecture of the GPT-2XL model, with significant differences being its composition of only three decoder blocks and an embedding dimension of 256, compared to the original's 48 decoder blocks and an embedding dimension of 1600. By scaling down the GPT model to about 5 million parameters, we can train it on a regular computer. \n",
    "\n",
    "The generated text's style will depend on the training data. When training a model from scratch for text generation, both text length and variation are crucial. The training material must be extensive enough for the model to learn and mimic a particular writing style effectively. At the same time, if the training material lacks variation, the model may simply replicate passages from the training text. On the other hand, if the material is too long, training may require excessive computational resources. Therefore, we will use three novels by Ernest Hemingway as our training material: The Old Man and the Sea, A Farewell to Arms, and For Whom the Bell Tolls. This selection ensures that our training data has sufficient length and variation for effective learning, without being so long that training becomes impractical.\n",
    "\n",
    "Since GPT models cannot process raw text directly, we will first tokenize the text into words. We will then create a dictionary to map each unique token to an index. Using this dictionary, we will convert the text into a long sequence of integers, ready for input into a neural network.\n",
    "\n",
    "We will use sequences of 128 indexes as input to train the GPT model. As in Chapters 8 and 10, we will shift the input sequence by one token to the right and use it as the output. This approach forces the model to predict the next word in a sentence based on the current token and all previous tokens in the sequence.\n",
    "\n",
    "A key challenge is determining the optimal number of epochs for training the model. Our goal is not merely to minimize the cross-entropy loss, as doing so could lead to overfitting, where the model simply replicates passages from the training text. To tackle this issue, we plan to train the model for 40 epochs. We will save the model at ten-epoch intervals and evaluate which version can generate coherent text without merely copying passages from the training material. Alternatively, one could potentially use a validation set to assess the performance of the model and decide when to stop training, as we did in Chapter 2. \n",
    "\n",
    "Once our GPT model is trained, we will use it to generate text autoregressively, as we did in Chapter 11. We’ll test different versions of the trained model. The model trained for 40 epochs produces very coherent text, capturing Hemingway's distinctive style. However, it may also generate text partly copied from the training material, especially if the prompt is similar to passages in the training text. The model trained for 20 epochs also generates coherent text, albeit with occasional grammatical errors, but is less likely to directly copy from the training text.\n",
    "\n",
    "The primary goal of this chapter is not necessarily to generate the most coherent text possible, which presents significant challenges. Instead, our objective is to teach you how to build a GPT-style model from scratch, tailored to real-world applications and your specific needs. More importantly, this chapter outlines the steps involved in training a GPT model from scratch. You will learn how to select training text based on your objectives, tokenize the text and convert it to indexes, and prepare batches of training data. You will also learn how to determine the number of epochs for training. Once the model is trained, you will learn how to generate text using the model and how to avoid generating text directly copied from the training material.\n",
    "\n",
    "# 1.\tHow to build and train a GPT from scratch?\n",
    "# 2.\tTokenize text of Hemingway novels\n",
    "## 2.1. \tTokenize the text"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "04a47b14",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(\"files/OldManAndSea.txt\",\"r\", encoding='utf-8-sig') as f:\n",
    "    text=f.read()\n",
    "text=list(text)    #A\n",
    "for i in range(len(text)):\n",
    "    if text[i]=='\"':\n",
    "        if text[i+1]==' ' or text[i+1]=='\\n':\n",
    "            text[i]='”'    #B\n",
    "        if text[i+1]!=' ' and text[i+1]!='\\n':\n",
    "            text[i]='“'    #C\n",
    "    if text[i]==\"'\":\n",
    "        if text[i-1]!=' ' and text[i-1]!='\\n':\n",
    "            text[i]='’'    #D   \n",
    "text=\"\".join(text)    #E"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "daf84181",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "He was an old man who fished alone in a skiff in the Gulf Stream and he\n",
      "had gone eighty-four days now without taking a fish.  In the first\n",
      "forty days a boy had been with him.  But after forty days without a\n",
      "fish the boy’s parents had told him that th\n"
     ]
    }
   ],
   "source": [
    "with open(\"files/ToWhomTheBellTolls.txt\",\"r\", encoding='utf-8-sig') as f:\n",
    "    text1=f.read()    #A\n",
    "\n",
    "with open(\"files/FarewellToArms.txt\",\"r\", encoding='utf-8-sig') as f:\n",
    "    text2=f.read()    #B\n",
    "\n",
    "text=text+\" \"+text1+\" \"+text2    #C\n",
    "\n",
    "with open(\"files/ThreeNovels.txt\",\"w\", \n",
    "          encoding='utf-8-sig') as f:\n",
    "    f.write(text)    #D\n",
    "print(text[:250])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "928a5319",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[')', '.', '&', ':', '(', ';', '-', '!', '“', ' ', '‘', '”', '?', ',', '’']\n",
      "10599\n"
     ]
    }
   ],
   "source": [
    "text=text.lower().replace(\"\\n\", \" \")\n",
    "chars=set(text.lower())\n",
    "punctuations=[i for i in chars if i.isalpha()==False\n",
    "              and i.isdigit()==False]\n",
    "print(punctuations)\n",
    "\n",
    "for x in punctuations:\n",
    "    text=text.replace(f\"{x}\", f\" {x} \")\n",
    "text_tokenized=text.split()\n",
    "\n",
    "unique_tokens=set(text_tokenized)\n",
    "print(len(unique_tokens))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "8c3b73ed",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "the text contains 698207 words\n",
      "there are 10600 unique tokens\n",
      "{'.': 0, 'the': 1, ',': 2, '“': 3, '”': 4, 'and': 5, 'i': 6, 'to': 7, 'he': 8, 'it': 9}\n",
      "{0: '.', 1: 'the', 2: ',', 3: '“', 4: '”', 5: 'and', 6: 'i', 7: 'to', 8: 'he', 9: 'it'}\n"
     ]
    }
   ],
   "source": [
    "from collections import Counter   \n",
    "\n",
    "word_counts=Counter(text_tokenized)    \n",
    "words=sorted(word_counts, key=word_counts.get,\n",
    "                      reverse=True)     \n",
    "words.append(\"UNK\")    #A \n",
    "text_length=len(text_tokenized)\n",
    "ntokens=len(words)    #B\n",
    "print(f\"the text contains {text_length} words\")\n",
    "print(f\"there are {ntokens} unique tokens\")  \n",
    "word_to_int={v:k for k,v in enumerate(words)}    #C \n",
    "int_to_word={v:k for k,v in word_to_int.items()}    #D\n",
    "print({k:v for k,v in word_to_int.items() if k in words[:10]})\n",
    "print({k:v for k,v in int_to_word.items() if v in words[:10]})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "265e09e4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['he', 'was', 'an', 'old', 'man', 'who', 'fished', 'alone', 'in', 'a', 'skiff', 'in', 'the', 'gulf', 'stream', 'and', 'he', 'had', 'gone', 'eighty']\n",
      "[8, 16, 98, 110, 67, 85, 6052, 314, 14, 11, 1039, 14, 1, 3193, 507, 5, 8, 25, 223, 3125]\n"
     ]
    }
   ],
   "source": [
    "print(text_tokenized[0:20])\n",
    "wordidx=[word_to_int[w] for w in text_tokenized]  \n",
    "print([word_to_int[w] for w in text_tokenized[0:20]])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "da23c0b7",
   "metadata": {},
   "source": [
    "## 2.2\tCreate batches for training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "4d043291",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "seq_len=128  \n",
    "xys=[]\n",
    "for n in range(0, len(wordidx)-seq_len-1):\n",
    "    x = wordidx[n:n+seq_len]\n",
    "    y = wordidx[n+1:n+seq_len+1]\n",
    "    xys.append((torch.tensor(x),(torch.tensor(y))))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "bc72a756",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[   3,  129,    9,  ...,   11,  251,   10],\n",
      "        [   5,   41,   32,  ...,  995,   52,   23],\n",
      "        [   6,   25,   11,  ...,   15,    0,   24],\n",
      "        ...,\n",
      "        [1254,    0,    4,  ...,   15,    0,    3],\n",
      "        [  17,    8, 1388,  ...,    0,    8,   16],\n",
      "        [  55,   20,  156,  ...,   74,   76,   12]])\n",
      "tensor([[ 129,    9,   23,  ...,  251,   10,    1],\n",
      "        [  41,   32,   34,  ...,   52,   23,    1],\n",
      "        [  25,   11,   59,  ...,    0,   24,   25],\n",
      "        ...,\n",
      "        [   0,    4,    3,  ...,    0,    3,   93],\n",
      "        [   8, 1388,    1,  ...,    8,   16, 1437],\n",
      "        [  20,  156,  970,  ...,   76,   12,   29]])\n",
      "torch.Size([32, 128]) torch.Size([32, 128])\n"
     ]
    }
   ],
   "source": [
    "from torch.utils.data import DataLoader\n",
    "\n",
    "torch.manual_seed(42)\n",
    "batch_size=32\n",
    "loader = DataLoader(xys, batch_size=batch_size, shuffle=True)\n",
    "\n",
    "x,y=next(iter(loader))\n",
    "print(x)\n",
    "print(y)\n",
    "print(x.shape,y.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "09abf4ce",
   "metadata": {},
   "source": [
    "# 3\tBuild a GPT to generate text\n",
    "## 3.1\tModel the causal self-attention mechanism"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "92591cf6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "import math\n",
    "\n",
    "device=\"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "class GELU(nn.Module):\n",
    "    def forward(self, x):\n",
    "        return 0.5*x*(1.0+torch.tanh(math.sqrt(2.0/math.pi)*\\\n",
    "                       (x + 0.044715 * torch.pow(x, 3.0))))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "12a3100c",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Config():\n",
    "    def __init__(self):\n",
    "        self.n_layer = 3\n",
    "        self.n_head = 4\n",
    "        self.n_embd = 256\n",
    "        self.vocab_size = ntokens\n",
    "        self.block_size = 128 \n",
    "        self.embd_pdrop = 0.1\n",
    "        self.resid_pdrop = 0.1\n",
    "        self.attn_pdrop = 0.1\n",
    "        \n",
    "# instantiate a Config() class\n",
    "config=Config()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "51ce2b84",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn.functional as F\n",
    "class CausalSelfAttention(nn.Module):\n",
    "    def __init__(self, config):\n",
    "        super().__init__()\n",
    "        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)\n",
    "        self.c_proj = nn.Linear(config.n_embd, config.n_embd)\n",
    "        self.attn_dropout = nn.Dropout(config.attn_pdrop)\n",
    "        self.resid_dropout = nn.Dropout(config.resid_pdrop)\n",
    "        self.register_buffer(\"bias\", torch.tril(torch.ones(\\\n",
    "                   config.block_size, config.block_size))\n",
    "             .view(1, 1, config.block_size, config.block_size))\n",
    "        self.n_head = config.n_head\n",
    "        self.n_embd = config.n_embd\n",
    "\n",
    "    def forward(self, x):\n",
    "        B, T, C = x.size() \n",
    "        q, k ,v  = self.c_attn(x).split(self.n_embd, dim=2)\n",
    "        hs = C // self.n_head\n",
    "        k = k.view(B, T, self.n_head, hs).transpose(1, 2) \n",
    "        q = q.view(B, T, self.n_head, hs).transpose(1, 2) \n",
    "        v = v.view(B, T, self.n_head, hs).transpose(1, 2) \n",
    "\n",
    "        att = (q @ k.transpose(-2, -1)) *\\\n",
    "            (1.0 / math.sqrt(k.size(-1)))\n",
    "        att = att.masked_fill(self.bias[:,:,:T,:T] == 0, \\\n",
    "                              float('-inf'))\n",
    "        att = F.softmax(att, dim=-1)\n",
    "        att = self.attn_dropout(att)\n",
    "        y = att @ v \n",
    "        y = y.transpose(1, 2).contiguous().view(B, T, C)\n",
    "        y = self.resid_dropout(self.c_proj(y))\n",
    "        return y"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e09532a0",
   "metadata": {},
   "source": [
    "## 3.2\tBuild the GPT model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "dbba81a6",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Block(nn.Module):\n",
    "    def __init__(self, config):\n",
    "        super().__init__()\n",
    "        self.ln_1 = nn.LayerNorm(config.n_embd)\n",
    "        self.attn = CausalSelfAttention(config)\n",
    "        self.ln_2 = nn.LayerNorm(config.n_embd)\n",
    "        self.mlp = nn.ModuleDict(dict(\n",
    "            c_fc   = nn.Linear(config.n_embd, 4 * config.n_embd),\n",
    "            c_proj = nn.Linear(4 * config.n_embd, config.n_embd),\n",
    "            act    = GELU(),\n",
    "            dropout = nn.Dropout(config.resid_pdrop),\n",
    "        ))\n",
    "        m = self.mlp\n",
    "        self.mlpf=lambda x:m.dropout(m.c_proj(m.act(m.c_fc(x)))) \n",
    "\n",
    "    def forward(self, x):\n",
    "        x = x + self.attn(self.ln_1(x))\n",
    "        x = x + self.mlpf(self.ln_2(x))\n",
    "        return x\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "b9b00a1f",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Model(nn.Module):\n",
    "    def __init__(self, config):\n",
    "        super().__init__()\n",
    "        self.block_size = config.block_size\n",
    "        self.transformer = nn.ModuleDict(dict(\n",
    "            wte = nn.Embedding(config.vocab_size, config.n_embd),\n",
    "            wpe = nn.Embedding(config.block_size, config.n_embd),\n",
    "            drop = nn.Dropout(config.embd_pdrop),\n",
    "            h = nn.ModuleList([Block(config) \n",
    "                               for _ in range(config.n_layer)]),   \n",
    "            ln_f = nn.LayerNorm(config.n_embd),))\n",
    "        self.lm_head = nn.Linear(config.n_embd,\n",
    "                                 config.vocab_size, bias=False)      \n",
    "        for pn, p in self.named_parameters():\n",
    "            if pn.endswith('c_proj.weight'):    \n",
    "                torch.nn.init.normal_(p, mean=0.0, \n",
    "                  std=0.02/math.sqrt(2 * config.n_layer))\n",
    "    def forward(self, idx, targets=None):\n",
    "        b, t = idx.size()\n",
    "        pos = torch.arange(0,t,dtype=torch.long).unsqueeze(0).to(device)\n",
    "        tok_emb = self.transformer.wte(idx) \n",
    "        pos_emb = self.transformer.wpe(pos) \n",
    "        x = self.transformer.drop(tok_emb + pos_emb)\n",
    "        for block in self.transformer.h:\n",
    "            x = block(x)\n",
    "        x = self.transformer.ln_f(x)\n",
    "        logits = self.lm_head(x)\n",
    "        return logits\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "f03e828a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "number of parameters: 5.12M\n",
      "Model(\n",
      "  (transformer): ModuleDict(\n",
      "    (wte): Embedding(10600, 256)\n",
      "    (wpe): Embedding(128, 256)\n",
      "    (drop): Dropout(p=0.1, inplace=False)\n",
      "    (h): ModuleList(\n",
      "      (0-2): 3 x Block(\n",
      "        (ln_1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
      "        (attn): CausalSelfAttention(\n",
      "          (c_attn): Linear(in_features=256, out_features=768, bias=True)\n",
      "          (c_proj): Linear(in_features=256, out_features=256, bias=True)\n",
      "          (attn_dropout): Dropout(p=0.1, inplace=False)\n",
      "          (resid_dropout): Dropout(p=0.1, inplace=False)\n",
      "        )\n",
      "        (ln_2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
      "        (mlp): ModuleDict(\n",
      "          (c_fc): Linear(in_features=256, out_features=1024, bias=True)\n",
      "          (c_proj): Linear(in_features=1024, out_features=256, bias=True)\n",
      "          (act): GELU()\n",
      "          (dropout): Dropout(p=0.1, inplace=False)\n",
      "        )\n",
      "      )\n",
      "    )\n",
      "    (ln_f): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
      "  )\n",
      "  (lm_head): Linear(in_features=256, out_features=10600, bias=False)\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "model=Model(config)\n",
    "model.to(device)\n",
    "num=sum(p.numel() for p in model.transformer.parameters())\n",
    "print(\"number of parameters: %.2fM\" % (num/1e6,))\n",
    "print(model)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7e0908d2",
   "metadata": {},
   "source": [
    "# 4\tTrain the GPT model to generate text\n",
    "## 4.1\tTrain the GPT model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "1155a389",
   "metadata": {},
   "outputs": [],
   "source": [
    "lr=0.0001\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=lr)\n",
    "loss_func = nn.CrossEntropyLoss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "0f1f335d",
   "metadata": {},
   "outputs": [],
   "source": [
    "model.train()  \n",
    "for i in range(1,41):\n",
    "    tloss = 0.\n",
    "    for idx, (x,y) in enumerate(loader):\n",
    "        x,y=x.to(device),y.to(device)\n",
    "        output = model(x)\n",
    "        loss=loss_func(output.view(-1,output.size(-1)),\n",
    "                           y.view(-1))\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        nn.utils.clip_grad_norm_(model.parameters(),1)\n",
    "        optimizer.step()\n",
    "        tloss += loss.item()\n",
    "    print(f'epoch {i} loss {tloss/(idx+1)}') \n",
    "    if i%10==0:\n",
    "        torch.save(model.state_dict(),f'files/GPTe{i}.pth') "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d8c463f9",
   "metadata": {},
   "source": [
    "## 4.2\tA function to generate text"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "0350bead",
   "metadata": {},
   "outputs": [],
   "source": [
    "def sample(idx, weights, max_new_tokens, temperature=1.0, top_k=None):\n",
    "    model.eval()\n",
    "    model.load_state_dict(torch.load(weights))\n",
    "    # keep track of the length of the original indexes\n",
    "    original_length=len(idx[0])\n",
    "    # add a fixed number of tokens to prompt\n",
    "    for _ in range(max_new_tokens):\n",
    "        # if the text is more than 1024 tokenx, trim it\n",
    "        if idx.size(1) <= config.block_size:\n",
    "            idx_cond = idx  \n",
    "        else:\n",
    "            idx_cond = idx[:, -config.block_size:]\n",
    "        # predict the logits for the index in sequence\n",
    "        logits = model(idx_cond.to(device))\n",
    "        # pluck the logits at the final step; apply temperature \n",
    "        logits = logits[:, -1, :] / temperature\n",
    "        # crop the logits to only the top k options\n",
    "        if top_k is not None:\n",
    "            v, _ = torch.topk(logits, top_k)\n",
    "            logits[logits < v[:, [-1]]] = -float('Inf')\n",
    "        # apply softmax to get probabilities\n",
    "        probs = F.softmax(logits, dim=-1)\n",
    "        idx_next=torch.multinomial(probs,num_samples=1)\n",
    "        idx = torch.cat((idx, idx_next.cpu()), dim=1)\n",
    "    # keep only new tokens\n",
    "    return idx[:, original_length:]  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "cd43802b",
   "metadata": {},
   "outputs": [],
   "source": [
    "UNK=word_to_int[\"UNK\"]\n",
    "def generate(prompt, weights, max_new_tokens, temperature=1.0,\n",
    "             top_k=None):\n",
    "    assert len(prompt)>0, \"prompt must contain at least one token\"\n",
    "    text=prompt.lower().replace(\"\\n\", \" \")\n",
    "    for x in punctuations:\n",
    "        text=text.replace(f\"{x}\", f\" {x} \")\n",
    "    text_tokenized=text.split() \n",
    "    idx=[word_to_int.get(w,UNK) for w in text_tokenized]\n",
    "    idx=torch.LongTensor(idx).unsqueeze(0)\n",
    "    # add a fixed number of tokens to prompt\n",
    "    idx=sample(idx, weights, max_new_tokens, temperature=1.0, top_k=None)\n",
    "    # convert indexes to text\n",
    "    tokens=[int_to_word[i] for i in idx.squeeze().numpy()] \n",
    "    text=\" \".join(tokens)\n",
    "    for x in '''”).:;!?,-‘’''':\n",
    "        text=text.replace(f\" {x}\", f\"{x}\") \n",
    "    for x in '''“(-‘’''':\n",
    "        text=text.replace(f\"{x} \", f\"{x}\")     \n",
    "    return prompt+\" \"+text"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8bb92ae3",
   "metadata": {},
   "source": [
    "# 4.3\tText generation with different versions of the trained model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "aba56061",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "way.” “kümmel,” i said. “it’s the way to talk about it\n",
      "--------------------------------------------------\n",
      ",” robert jordan said. “but do not realize how far he is ruined.” “pero\n",
      "--------------------------------------------------\n",
      "in the fog, robert jordan thought. and then, without looking at last, so good, he\n",
      "--------------------------------------------------\n",
      "pot of yellow rice and fish and the boy loved him. “no,” the boy said.\n",
      "--------------------------------------------------\n",
      "the line now. it’s wonderful.” “he’s crazy about the brave.”\n",
      "--------------------------------------------------\n",
      "candle to us. “and if the maria kisses thee again i will commence kissing thee myself. it\n",
      "--------------------------------------------------\n",
      "?” “do you have to for the moment.” robert jordan got up and walked away in\n",
      "--------------------------------------------------\n",
      ". a uniform for my father, he thought. i’ll say them later. just then he\n",
      "--------------------------------------------------\n",
      "and more practical to read and relax in the evening; of all the things he had enjoyed the next\n",
      "--------------------------------------------------\n",
      "in bed and rolled himself a cigarette. when he gave them a log to a second grenade. “\n",
      "--------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "prompt=\"UNK\"\n",
    "for i in range(10):\n",
    "    torch.manual_seed(i)\n",
    "    print(generate(prompt,'files/GPTe20.pth',max_new_tokens=20)[4:])\n",
    "    print(\"-\"*50)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "10322c56",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "way.” “kümmel, and i will enjoy the killing. they must have brought me a spit\n",
      "--------------------------------------------------\n",
      ",” robert jordan said. “but do not tell me that he saw anything.” “not\n",
      "--------------------------------------------------\n",
      "in the first time he had bit the ear like that and held onto it, his neck and jaws\n",
      "--------------------------------------------------\n",
      "pot of yellow rice with fish. it was cold now in the head and he could not see the\n",
      "--------------------------------------------------\n",
      "the line of his mouth. he thought.” “the laughing hurt him.” “i can\n",
      "--------------------------------------------------\n",
      "candle made? that was the worst day of my life until one other day.” “don’\n",
      "--------------------------------------------------\n",
      "?” “do you have to for the moment.” robert jordan took the glasses and opened the\n",
      "--------------------------------------------------\n",
      ". that’s what they don’t marry.” i reached for her hand. “don\n",
      "--------------------------------------------------\n",
      "and more grenades. that was the last for next year. it crossed the river away from the front\n",
      "--------------------------------------------------\n",
      "in a revolutionary army,” robert jordan said. “that’s really nonsense. it’s\n",
      "--------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "prompt=\"UNK\"\n",
    "for i in range(10):\n",
    "    torch.manual_seed(i)\n",
    "    print(generate(prompt,'files/GPTe40.pth',max_new_tokens=20)[4:])\n",
    "    print(\"-\"*50)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "06a5f72b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      ". i know that the doctor was a doctor who would stay with a beard and wore a red scar across the room. “how is she in the legs?” “she could not come,” i said. “you can go too. if\n"
     ]
    }
   ],
   "source": [
    "# Answer to exercis 12.1\n",
    "prompt=\"UNK\"\n",
    "torch.manual_seed(42)\n",
    "print(generate(prompt,'files/GPTe10.pth',max_new_tokens=50)[4:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "d3ed269b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "the old man saw the shark near the old man’s head with his tail out and the old man hit him squarely in the center of\n",
      "--------------------------------------------------\n",
      "the old man saw the shark near the boat with one hand. he had no feeling of the morning but he started to pull on it gently\n",
      "--------------------------------------------------\n",
      "the old man saw the shark near the old man’s head. then he went back to another man in and leaned over and dipped the\n",
      "--------------------------------------------------\n",
      "the old man saw the shark near the fish now, and the old man was asleep in the water as he rowed he was out of the\n",
      "--------------------------------------------------\n",
      "the old man saw the shark near the boat. it was a nice-boat. he saw the old man’s head and he started\n",
      "--------------------------------------------------\n",
      "the old man saw the shark near the boat to see him clearly and he was afraid that he was higher out of the water and the old\n",
      "--------------------------------------------------\n",
      "the old man saw the shark near the old man’s head and then, with his tail lashing and his jaws clicking, the shark plowed\n",
      "--------------------------------------------------\n",
      "the old man saw the shark near the line with his tail which was not sweet smelling it. the old man knew that the fish was coming\n",
      "--------------------------------------------------\n",
      "the old man saw the shark near the fish with his jaws hooked and the old man stabbed him in his left eye. the shark still hung\n",
      "--------------------------------------------------\n",
      "the old man saw the shark near the fish and he started to shake his head again. the old man was asleep in the stern and he\n",
      "--------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "prompt=\"the old man saw the shark near the\"\n",
    "for i in range(10):\n",
    "    torch.manual_seed(i)\n",
    "    print(generate(prompt,'files/GPTe40.pth',max_new_tokens=20))\n",
    "    print(\"-\"*50)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "ea157ffe",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "the old man saw the shark near the boat. then he swung the great fish that was more comfortable in the sun. the old man could\n",
      "--------------------------------------------------\n",
      "the old man saw the shark near the boat with one hand. he wore his overcoat and carried the submachine gun muzzle down, carrying it in\n",
      "--------------------------------------------------\n",
      "the old man saw the shark near the boat with its long dip sharply and the old man stabbed him in the morning. he could not see\n",
      "--------------------------------------------------\n",
      "the old man saw the shark near the fish that was now heavy and long and grave he had taken no part in. he was still under\n",
      "--------------------------------------------------\n",
      "the old man saw the shark near the boat. it was a nice little light. then he rowed out and the old man was asleep over\n",
      "--------------------------------------------------\n",
      "the old man saw the shark near the boat to come. “old man’s shack and i’ll fill the water with him in\n",
      "--------------------------------------------------\n",
      "the old man saw the shark near the boat and then rose with his lines close him over the stern. “no,” the old man\n",
      "--------------------------------------------------\n",
      "the old man saw the shark near the line with his tail go under. he was cutting away onto the bow and his face was just a\n",
      "--------------------------------------------------\n",
      "the old man saw the shark near the fish with his tail that he swung him in. the shark’s head was out of water and\n",
      "--------------------------------------------------\n",
      "the old man saw the shark near the boat and he started to cry. he could almost have them come down and whipped him in again.\n",
      "--------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "prompt=\"the old man saw the shark near the\"\n",
    "for i in range(10):\n",
    "    torch.manual_seed(i)\n",
    "    print(generate(prompt,'files/GPTe20.pth',max_new_tokens=20,\n",
    "                  temperature=0.9,top_k=50))\n",
    "    print(\"-\"*50)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "0c80a275",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "the old man saw the shark near the old man’s head. then he went back inside the skiff and rested on the line as he leaned over the side-side and washed the flying fish in the water, noting the speed of the water against his hand. his hand was phosphorescent from\n"
     ]
    }
   ],
   "source": [
    "# answer to exercise 12.2\n",
    "prompt=\"the old man saw the shark near the\"\n",
    "torch.manual_seed(42)\n",
    "print(generate(prompt,'files/GPTe40.pth',max_new_tokens=50,\n",
    "                  temperature=0.95,top_k=100))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "813ce7da",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
